pytorch预训练模型
1.加载预训练模型:
只加载模型,不加载预训练参数:resnet18 = models.resnet18(pretrained=False)
print resnet18 打印模型结构
resnet18.load_state_dict(torch.load('resnet18-5c106cde.pth'))加载预先下载好的预训练参数到resnet18
print resnet18 打印的还是模型结构
note: cnn = resnet18.load_state_dict(torch.load('resnet18-5c106cde.pth'))是错误的,这样cnn将是nonetype
pre_dict = resnet18.state_dict()按键值对将模型参数加载到pre_dict
print for k, v in pre_dict.items(): 打印模型参数
for k, v in pre_dict.items():
print k
打印模型每层命名
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
note:model是自己定义好的模型,将pretrained_dict和model_dict中命名一致的层加入pretrained_dict(包括参数)
加载模型和预训练参数:resnet34 = models.resnet34(pretrained=True)
reference:
1.
http://blog.csdn.net/VictoriaW/article/details/72821329
2.
vgg16 = models.vgg16(pretrained=True)
pretrained_dict = vgg16.state_dict()
model_dict = model.state_dict()
# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
model.load_state_dict(model_dict)